{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "62c80e4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class RNNEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        # 隐层大小\n",
    "        self.hidden_size = hidden_size\n",
    "        # 词表大小\n",
    "        self.vocab_size = vocab_size\n",
    "        # 词嵌入层\n",
    "        self.embedding = nn.Embedding(self.vocab_size,\\\n",
    "            self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # inputs: batch * seq_len\n",
    "        # 注意门控循环单元使用batch_first=True，因此输入需要至少batch为1\n",
    "        features = self.embedding(inputs)\n",
    "        output, hidden = self.gru(features)\n",
    "        return output, hidden"
   ]
  },
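  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3f9c1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (not part of the original tutorial): run a dummy index\n",
    "# sequence through RNNEncoder to check the output shapes. The sizes below\n",
    "# (vocab_size=10, hidden_size=8, seq_len=5) are arbitrary assumptions.\n",
    "_demo_encoder = RNNEncoder(vocab_size=10, hidden_size=8)\n",
    "_demo_inputs = torch.randint(0, 10, (1, 5))  # batch=1, seq_len=5\n",
    "_demo_output, _demo_hidden = _demo_encoder(_demo_inputs)\n",
    "# output: batch * seq_len * hidden_size, hidden: num_layers * batch * hidden_size\n",
    "print(_demo_output.shape, _demo_hidden.shape)"
   ]
  },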
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "14477ff6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        # 序列到序列任务并不限制编码器和解码器输入同一种语言，\n",
    "        # 因此解码器也需要定义一个嵌入层\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # 用于将输出的隐状态映射为词表上的分布\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # 解码整个序列\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # 从<sos>开始解码\n",
    "        decoder_input = torch.empty(batch_size, 1,\\\n",
    "            dtype=torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        # 如果目标序列确定，最大解码步数确定；\n",
    "        # 如果目标序列不确定，解码到最大长度\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # 进行seq_length次解码\n",
    "        for i in range(seq_length):\n",
    "            # 每次输入一个词和一个隐状态\n",
    "            decoder_output, decoder_hidden = self.forward_step(\\\n",
    "                decoder_input, decoder_hidden)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # 为了与AttnRNNDecoder接口保持统一，最后输出None\n",
    "        return decoder_outputs, decoder_hidden, None\n",
    "\n",
    "    # 解码一步\n",
    "    def forward_step(self, input, hidden):\n",
    "        output = self.embedding(input)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "211226b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class BahdanauAttention(nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        # query: batch * 1 * hidden_size\n",
    "        # keys: batch * seq_length * hidden_size\n",
    "        # 这一步用到了广播（broadcast）机制\n",
    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
    "        scores = scores.squeeze(2).unsqueeze(1)\n",
    "\n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        context = torch.bmm(weights, keys)\n",
    "        return context, weights\n",
    "\n",
    "class AttnRNNDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(AttnRNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        # 输入来自解码器输入和上下文向量，因此输入大小为2 * hidden_size\n",
    "        self.gru = nn.GRU(2 * self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # 用于将注意力的结果映射为词表上的分布\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # 解码整个序列\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # 从<sos>开始解码\n",
    "        decoder_input = torch.empty(batch_size, 1, dtype=\\\n",
    "            torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "\n",
    "        # 如果目标序列确定，最大解码步数确定；\n",
    "        # 如果目标序列不确定，解码到最大长度\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # 进行seq_length次解码\n",
    "        for i in range(seq_length):\n",
    "            # 每次输入一个词和一个隐状态\n",
    "            decoder_output, decoder_hidden, attn_weights = \\\n",
    "                self.forward_step(\n",
    "                    decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "        # 与RNNDecoder接口保持统一，最后输出注意力权重\n",
    "        return decoder_outputs, decoder_hidden, attentions\n",
    "\n",
    "    # 解码一步\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        embeded =  self.embedding(input)\n",
    "        # 输出的隐状态为1 * batch * hidden_size，\n",
    "        # 注意力的输入需要batch * 1 * hidden_size\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embeded, context), dim=2)\n",
    "        # 输入的隐状态需要1 * batch * hidden_size\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden, attn_weights\n"
   ]
  },
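  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e2d4a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (not part of the original tutorial): check the shapes\n",
    "# produced by BahdanauAttention. The sizes below (batch=2, seq_length=5,\n",
    "# hidden_size=8) are arbitrary assumptions for the demo.\n",
    "_attn = BahdanauAttention(hidden_size=8)\n",
    "_query = torch.randn(2, 1, 8)   # batch * 1 * hidden_size\n",
    "_keys = torch.randn(2, 5, 8)    # batch * seq_length * hidden_size\n",
    "_context, _weights = _attn(_query, _keys)\n",
    "# context: batch * 1 * hidden_size, weights: batch * 1 * seq_length\n",
    "print(_context.shape, _weights.shape)"
   ]
  },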
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d791a2a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: transformers in d:\\program files\\anaconda3\\lib\\site-packages (2.1.1)\n",
      "Requirement already satisfied: numpy in d:\\program files\\anaconda3\\lib\\site-packages (from transformers) (1.26.4)\n",
      "Requirement already satisfied: boto3 in d:\\program files\\anaconda3\\lib\\site-packages (from transformers) (1.24.28)\n",
      "Requirement already satisfied: requests in d:\\program files\\anaconda3\\lib\\site-packages (from transformers) (2.31.0)\n",
      "Requirement already satisfied: tqdm in d:\\program files\\anaconda3\\lib\\site-packages (from transformers) (4.65.0)\n",
      "Requirement already satisfied: regex in d:\\program files\\anaconda3\\lib\\site-packages (from transformers) (2022.7.9)\n",
      "Collecting sentencepiece (from transformers)\n",
      "  Obtaining dependency information for sentencepiece from https://files.pythonhosted.org/packages/a2/f6/587c62fd21fc988555b85351f50bbde43a51524caafd63bc69240ded14fd/sentencepiece-0.2.0-cp311-cp311-win_amd64.whl.metadata\n",
      "  Downloading sentencepiece-0.2.0-cp311-cp311-win_amd64.whl.metadata (8.3 kB)\n",
      "Requirement already satisfied: sacremoses in d:\\program files\\anaconda3\\lib\\site-packages (from transformers) (0.0.43)\n",
      "Requirement already satisfied: botocore<1.28.0,>=1.27.28 in d:\\program files\\anaconda3\\lib\\site-packages (from boto3->transformers) (1.27.59)\n",
      "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in d:\\program files\\anaconda3\\lib\\site-packages (from boto3->transformers) (0.10.0)\n",
      "Requirement already satisfied: s3transfer<0.7.0,>=0.6.0 in d:\\program files\\anaconda3\\lib\\site-packages (from boto3->transformers) (0.6.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in d:\\program files\\anaconda3\\lib\\site-packages (from requests->transformers) (2.0.4)\n",
      "Requirement already satisfied: idna<4,>=2.5 in d:\\program files\\anaconda3\\lib\\site-packages (from requests->transformers) (3.4)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in d:\\program files\\anaconda3\\lib\\site-packages (from requests->transformers) (1.26.16)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in d:\\program files\\anaconda3\\lib\\site-packages (from requests->transformers) (2023.7.22)\n",
      "Requirement already satisfied: six in d:\\program files\\anaconda3\\lib\\site-packages (from sacremoses->transformers) (1.16.0)\n",
      "Requirement already satisfied: click in d:\\program files\\anaconda3\\lib\\site-packages (from sacremoses->transformers) (8.0.4)\n",
      "Requirement already satisfied: joblib in d:\\program files\\anaconda3\\lib\\site-packages (from sacremoses->transformers) (1.2.0)\n",
      "Requirement already satisfied: colorama in d:\\program files\\anaconda3\\lib\\site-packages (from tqdm->transformers) (0.4.6)\n",
      "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in d:\\program files\\anaconda3\\lib\\site-packages (from botocore<1.28.0,>=1.27.28->boto3->transformers) (2.8.2)\n",
      "Downloading sentencepiece-0.2.0-cp311-cp311-win_amd64.whl (991 kB)\n",
      "   ---------------------------------------- 0.0/991.5 kB ? eta -:--:--\n",
      "   - ------------------------------------- 41.0/991.5 kB 991.0 kB/s eta 0:00:01\n",
      "   --- ----------------------------------- 92.2/991.5 kB 871.5 kB/s eta 0:00:02\n",
      "   ----- -------------------------------- 143.4/991.5 kB 944.1 kB/s eta 0:00:01\n",
      "   -------- ------------------------------- 204.8/991.5 kB 1.2 MB/s eta 0:00:01\n",
      "   ------------------- -------------------- 471.0/991.5 kB 2.0 MB/s eta 0:00:01\n",
      "   -------------------------- ------------- 655.4/991.5 kB 2.3 MB/s eta 0:00:01\n",
      "   ---------------------------------------- 991.5/991.5 kB 3.0 MB/s eta 0:00:00\n",
      "Installing collected packages: sentencepiece\n",
      "Successfully installed sentencepiece-0.2.0\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "feca8cd8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "from transformer import *\n",
    "\n",
    "class TransformerEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # 直接使用TransformerLayer作为编码层，简单起见只使用一层\n",
    "        self.layer = TransformerLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # 与TransformerLM不同，编码器不需要线性层用于输出\n",
    "        \n",
    "    def forward(self, input_ids):\n",
    "        # 这里实现的forward()函数一次只能处理一句话，\n",
    "        # 如果想要支持批次运算，需要根据输入序列的长度返回隐状态\n",
    "        assert input_ids.ndim == 2 and input_ids.size(0) == 1\n",
    "        seq_len = input_ids.size(1)\n",
    "        assert seq_len <= self.embedding_layer.max_len\n",
    "        \n",
    "        # 1 * seq_len\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        attention_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        input_states = self.embedding_layer(input_ids, pos_ids)\n",
    "        hidden_states = self.layer(input_states, attention_mask)\n",
    "        return hidden_states, attention_mask"
   ]
  },
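  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d8e0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (not part of the original tutorial): encode one dummy\n",
    "# sentence with TransformerEncoder. This assumes the book's local\n",
    "# ./code/transformer module (imported above) is available; all sizes below\n",
    "# are arbitrary. Expected shapes: states 1 * seq_len * hidden_size, mask 1 * seq_len.\n",
    "_t_enc = TransformerEncoder(vocab_size=10, max_len=16, hidden_size=8,\\\n",
    "    num_heads=2, dropout=0.1, intermediate_size=16)\n",
    "_t_ids = torch.randint(0, 10, (1, 6))\n",
    "_t_states, _t_mask = _t_enc(_t_ids)\n",
    "print(_t_states.shape, _t_mask.shape)"
   ]
  },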
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2ec5e2ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadCrossAttention(MultiHeadSelfAttention):\n",
    "    def forward(self, tgt, tgt_mask, src, src_mask):\n",
    "        \"\"\"\n",
    "        tgt: query, batch_size * tgt_seq_len * hidden_size\n",
    "        tgt_mask: batch_size * tgt_seq_len\n",
    "        src: keys/values, batch_size * src_seq_len * hidden_size\n",
    "        src_mask: batch_size * src_seq_len\n",
    "        \"\"\"\n",
    "        # (batch_size * num_heads) * seq_len * (hidden_size / num_heads)\n",
    "        queries = self.transpose_qkv(self.W_q(tgt))\n",
    "        keys = self.transpose_qkv(self.W_k(src))\n",
    "        values = self.transpose_qkv(self.W_v(src))\n",
    "        # 这一步与自注意力不同，计算交叉掩码\n",
    "        # batch_size * tgt_seq_len * src_seq_len\n",
    "        attention_mask = tgt_mask.unsqueeze(2) * src_mask.unsqueeze(1)\n",
    "        # 重复张量的元素，用以支持多个注意力头的运算\n",
    "        # (batch_size * num_heads) * tgt_seq_len * src_seq_len\n",
    "        attention_mask = torch.repeat_interleave(attention_mask,\\\n",
    "            repeats=self.num_heads, dim=0)\n",
    "        # (batch_size * num_heads) * tgt_seq_len * \\\n",
    "        # (hidden_size / num_heads)\n",
    "        output = self.attention(queries, keys, values, attention_mask)\n",
    "        # batch * tgt_seq_len * hidden_size\n",
    "        output_concat = self.transpose_output(output)\n",
    "        return self.W_o(output_concat)\n",
    "\n",
    "# TransformerDecoderLayer比TransformerLayer多了交叉多头注意力\n",
    "class TransformerDecoderLayer(nn.Module):\n",
    "    def __init__(self, hidden_size, num_heads, dropout,\\\n",
    "                 intermediate_size):\n",
    "        super().__init__()\n",
    "        self.self_attention = MultiHeadSelfAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm1 = AddNorm(hidden_size, dropout)\n",
    "        self.enc_attention = MultiHeadCrossAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm2 = AddNorm(hidden_size, dropout)\n",
    "        self.fnn = PositionWiseFNN(hidden_size, intermediate_size)\n",
    "        self.add_norm3 = AddNorm(hidden_size, dropout)\n",
    "\n",
    "    def forward(self, src_states, src_mask, tgt_states, tgt_mask):\n",
    "        # 掩码多头自注意力\n",
    "        tgt = self.add_norm1(tgt_states, self.self_attention(\\\n",
    "            tgt_states, tgt_states, tgt_states, tgt_mask))\n",
    "        # 交叉多头自注意力\n",
    "        tgt = self.add_norm2(tgt, self.enc_attention(tgt,\\\n",
    "            tgt_mask, src_states, src_mask))\n",
    "        # 前馈神经网络\n",
    "        return self.add_norm3(tgt, self.fnn(tgt))\n",
    "\n",
    "class TransformerDecoder(nn.Module):\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "                 dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # 简单起见只使用一层\n",
    "        self.layer = TransformerDecoderLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # 解码器与TransformerLM一样，需要输出层\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "        \n",
    "    def forward(self, src_states, src_mask, tgt_tensor=None):\n",
    "        # 确保一次只输入一句话，形状为1 * seq_len * hidden_size\n",
    "        assert src_states.ndim == 3 and src_states.size(0) == 1\n",
    "        \n",
    "        if tgt_tensor is not None:\n",
    "            # 确保一次只输入一句话，形状为1 * seq_len\n",
    "            assert tgt_tensor.ndim == 2 and tgt_tensor.size(0) == 1\n",
    "            seq_len = tgt_tensor.size(1)\n",
    "            assert seq_len <= self.embedding_layer.max_len\n",
    "        else:\n",
    "            seq_len = self.embedding_layer.max_len\n",
    "        \n",
    "        decoder_input = torch.empty(1, 1, dtype=torch.long).\\\n",
    "            fill_(SOS_token)\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        for i in range(seq_len):\n",
    "            decoder_output = self.forward_step(decoder_input,\\\n",
    "                src_mask, src_states)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            \n",
    "            if tgt_tensor is not None:\n",
    "                # teacher forcing: 使用真实目标序列作为下一步的输入\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    tgt_tensor[:, i:i+1]), 1)\n",
    "            else:\n",
    "                # 从当前步的输出概率分布中选取概率最大的预测结果\n",
    "                # 作为下一步的输入\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # 使用detach从当前计算图中分离，避免回传梯度\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    topi.squeeze(-1).detach()), 1)\n",
    "                \n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # 与RNNDecoder接口保持统一\n",
    "        return decoder_outputs, None, None\n",
    "        \n",
    "    # 解码一步，与RNNDecoder接口略有不同，RNNDecoder一次输入\n",
    "    # 一个隐状态和一个词，输出一个分布、一个隐状态\n",
    "    # TransformerDecoder不需要输入隐状态，\n",
    "    # 输入整个目标端历史输入序列，输出一个分布，不输出隐状态\n",
    "    def forward_step(self, tgt_inputs, src_mask, src_states):\n",
    "        seq_len = tgt_inputs.size(1)\n",
    "        # 1 * seq_len\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        tgt_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        tgt_states = self.embedding_layer(tgt_inputs, pos_ids)\n",
    "        hidden_states = self.layer(src_states, src_mask, tgt_states,\\\n",
    "            tgt_mask)\n",
    "        output = self.output_layer(hidden_states[:, -1:, :])\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "684cbf3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "class Lang:\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {0: \"<sos>\", 1: \"<eos>\"}\n",
    "        self.n_words = 2  # Count SOS and EOS\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        for word in sentence.split(' '):\n",
    "            self.addWord(word)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        if word not in self.word2index:\n",
    "            self.word2index[word] = self.n_words\n",
    "            self.word2count[word] = 1\n",
    "            self.index2word[self.n_words] = word\n",
    "            self.n_words += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "            \n",
    "    def sent2ids(self, sent):\n",
    "        return [self.word2index[word] for word in sent.split(' ')]\n",
    "    \n",
    "    def ids2sent(self, ids):\n",
    "        return ' '.join([self.index2word[idx] for idx in ids])\n",
    "\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "# 文件使用unicode编码，我们将unicode转为ASCII，转为小写，并修改标点\n",
    "def unicodeToAscii(s):\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "def normalizeString(s):\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    # 在标点前插入空格\n",
    "    s = re.sub(r\"([,.!?])\", r\" \\1\", s)\n",
    "    return s.strip()"
   ]
  },
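  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a6b7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (not part of the original tutorial): normalize one\n",
    "# sentence, add it to a Lang vocabulary, and round-trip it through\n",
    "# sent2ids / ids2sent. Indices start at 2 because 0/1 are <sos>/<eos>.\n",
    "_demo_lang = Lang('en')\n",
    "_demo_sent = normalizeString('Hello, world!')\n",
    "_demo_lang.addSentence(_demo_sent)\n",
    "print(_demo_sent)                          # hello , world !\n",
    "print(_demo_lang.sent2ids(_demo_sent))     # [2, 3, 4, 5]\n",
    "print(_demo_lang.ids2sent(_demo_lang.sent2ids(_demo_sent)))"
   ]
  },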
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3fc369de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取文件，一共有两个文件，两个文件的同一行对应一对源语言和目标语言句子\n",
    "def readLangs(lang1, lang2):\n",
    "    # 读取文件，分句\n",
    "    lines1 = open(f'{lang1}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    lines2 = open(f'{lang2}.txt', encoding='utf-8').read()\\\n",
    "        .strip().split('\\n')\n",
    "    print(len(lines1), len(lines2))\n",
    "    \n",
    "    # 规范化\n",
    "    lines1 = [normalizeString(s) for s in lines1]\n",
    "    lines2 = [normalizeString(s) for s in lines2]\n",
    "    if lang1 == 'zh':\n",
    "        lines1 = [' '.join(list(s.replace(' ', ''))) for s in lines1]\n",
    "    if lang2 == 'zh':\n",
    "        lines2 = [' '.join(list(s.replace(' ', ''))) for s in lines2]\n",
    "    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]\n",
    "\n",
    "    input_lang = Lang(lang1)\n",
    "    output_lang = Lang(lang2)\n",
    "    return input_lang, output_lang, pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c1e4b94a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n",
      "['精 益 创 业 实 战 ( 第 2 版 )', 'leaning entrepreneurship (version 2)']\n"
     ]
    }
   ],
   "source": [
    "# 为了快速训练，过滤掉一些过长的句子\n",
    "MAX_LENGTH = 30\n",
    "\n",
    "def filterPair(p):\n",
    "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
    "        len(p[1].split(' ')) < MAX_LENGTH\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    return [pair for pair in pairs if filterPair(pair)]\n",
    "\n",
    "def prepareData(lang1, lang2):\n",
    "    input_lang, output_lang, pairs = readLangs(lang1, lang2)\n",
    "    print(f\"读取 {len(pairs)} 对序列\")\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(f\"过滤后剩余 {len(pairs)} 对序列\")\n",
    "    print(\"统计词数\")\n",
    "    for pair in pairs:\n",
    "        input_lang.addSentence(pair[0])\n",
    "        output_lang.addSentence(pair[1])\n",
    "    print(input_lang.name, input_lang.n_words)\n",
    "    print(output_lang.name, output_lang.n_words)\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e03ced9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n"
     ]
    }
   ],
   "source": [
    "def get_train_data():\n",
    "    input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "    train_data = []\n",
    "    for idx, (src_sent, tgt_sent) in enumerate(pairs):\n",
    "        src_ids = input_lang.sent2ids(src_sent)\n",
    "        tgt_ids = output_lang.sent2ids(tgt_sent)\n",
    "        # 添加<eos>\n",
    "        src_ids.append(EOS_token)\n",
    "        tgt_ids.append(EOS_token)\n",
    "        train_data.append([src_ids, tgt_ids])\n",
    "    return input_lang, output_lang, train_data\n",
    "        \n",
    "input_lang, output_lang, train_data = get_train_data()"
   ]
  },
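  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f7a8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (not part of the original tutorial): map the ids of the\n",
    "# first training pair back to words to verify the preprocessing. The final\n",
    "# index of each sequence is the appended <eos> token.\n",
    "_src_ids, _tgt_ids = train_data[0]\n",
    "print(input_lang.ids2sent(_src_ids))\n",
    "print(output_lang.ids2sent(_tgt_ids))"
   ]
  },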
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3dfc1853",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch-19, loss=0.0471: 100%|█| 20/20 [13:11<00:00, 39.60s/it\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABGPElEQVR4nO3dd3hUZf7+8XvSQxICARIIhBp670awIChIEXt3Udf1iwuCYsW1t2DHimtZ1F0VfiuKrHSUpvQAgoCEDpJAKCEJhNQ5vz8wQyaZSSbJJGdm8n5dV66dU+dzcli4fc5znsdiGIYhAAAAD+RndgEAAADOEFQAAIDHIqgAAACPRVABAAAei6ACAAA8FkEFAAB4LIIKAADwWAFmF1AVVqtVKSkpioiIkMViMbscAADgAsMwlJWVpdjYWPn5ld1m4tVBJSUlRXFxcWaXAQAAKuHQoUNq1qxZmft4dVCJiIiQdO5C69ata3I1AADAFZmZmYqLi7P9O14Wrw4qRY976tatS1ABAMDLuNJtg860AADAYxFUAACAxyKoAAAAj0VQAQAAHougAgAAPBZBBQAAeCyCCgAA8FgEFQAA4LEIKgAAwGMRVAAAgMciqAAAAI9FUAEAAB7LqyclrC5n8wp1MjtPgX4WRdcNMbscAABqLVpUHFi47YgGTPlJk/7fr2aXAgBArUZQccDf79y00/mFVpMrAQCgdiOoOBDofy6oFFoNkysBAKB2I6g44O937teST1ABAMBUBBUHAmwtKjz6AQDATAQVBwzjXEvKb4czTa4EAIDajaDiwKrdJ8wuAQAAyOSgUlBQoCeffFKtWrVSaGioWrdureeff15Wkx+5XN2zqanfDwAAzjF1wLdXXnlFH374oT7//HN17txZGzZs0F133aXIyEhNnDjRtLrq1Qk07bsBAMB5pgaV1atXa/To0RoxYoQkqWXLlvr666+1YcMGh/vn5uYqNzfXtpyZWT19SAL8zjc0nc0rVGiQf7V8DwAAKJupj34GDhyoH3/8UcnJyZKkX3/9VT///LOGDx/ucP/ExERFRkbafuLi4qqlrgbhQbbPuQWF1fIdAACgfKa2qDz22GPKyMhQhw4d5O/vr8LCQr300ku65ZZbHO4/efJkTZo0ybacmZlZLWEl4M+RaSUGfQMAwEymBpWZM2fqP//5j7766it17txZmzdv1gMPPKDY2FiNGTOm1P7BwcEKDg6u9rosFossFskwpEKDoAIAgFlMDSqPPPKIHn/8cd18882SpK5du+rAgQNKTEx0GFRqkr/FogLDEGO+AQBgHlP7qGRnZ8vPz74Ef39/019PliS/Px//0KICAIB5TG1RGTVqlF566SU1b95cnTt31qZNm/Tmm2/q7rvvNrMsSef6qeRJstJHBQAA05gaVN5991099dRT+vvf/660tDTFxsbq//7v//T000+bWZakc49+JKmAoAIAgGlMDSoRERGaOnWqpk6damYZDv2ZU2zz/gAAgJrHXD9OFPVRoUEFAADzEFSc8PuzSYUWFQAAzENQcaJozDdaVAAAMA9BxQmLpejRD0kFAACzEFScKGpRYQh9AADMQ1Bx4nwfFZMLAQCgFiOoOOHHox8AAExHUHGiaGR/ggoAAOYhqDhxvkXF5EIAAKjFCCpOMI4KAADmI6g4YWEcFQAATEdQcYLOtAAAmI+g4sT5kWkJKgAAmIWg4gTjqAAAYD6CihMMoQ8AgPkIKk4wKSEAAOYjqDhBZ1oAAMxHUHGiqEWFcVQAADAPQcUJWx8Vq8mFAABQixFUnOD1ZAAAzEdQcYK5fgAAMB9BxQnm+gEAwHwEFSeY6wcAAPMRVJzg9WQAAMxHUHHC78/fDEEFAADzEFScYK4fAADMR1Bxgrl+AAAwH0HFCeb6AQDAfAQVJ+hMCwCA+QgqTjDXDwAA5iOoOGFhZFoAAExHUHGiqEWlkKQCAIBpCCpOMIQ+AADmI6g44efHox8AAMxGUHGCt34AADAfQcUJxlEBAMB8BBUn6KMCAID5CCpOWGwtKgQVAADMQlBxwo9xVAAAMB1BxQk/WlQAADAdQcWJ831UTC4EAIBajKDiRMGfz3zyCqwmVwIAQO1FUHHim6Q/JElv/7jL5EoAAKi9CCoAAMBjEVSc6NuyviRpcIdokysBAKD2Iqg40TwqTBJv/QAAYCaCihOzNp7ro7J05zGTKwEAoPYiqAAAAI9FUAEAAB6LoAIAADwWQcWJV6/rZnYJAADUegQVJxrVDZYkdW0aaXIlAADUXgQVJwL+nJUwLSvH5EoAAKi9CCpO+P85KeHRzFxtS8kwuRoAAGongooT/n+2qEjSfzf8YWIlAADUXgQVJ3KZNRkAANMRVJzILzwfVJYnMzotAABmIKg4ERRw/lez7/gZEysBAKD2Iqg4YZGl/J0AAEC1Iqg4UWCljwoAAGYjqDjRrH4du+X1+0+aVAkAALUXQcWJ+Ohwu+XEeTtMqgQAgNqLoOIiq2F2BQAA1D4ElTL0bVnf9nnzoVM6nVtgYjUAANQ+BJUyrN+fbrfc76UlJlUCAEDtRFCpgOy8QrNLAACgViGoVJBh0FkFAICaQlCpoFaT5ykzJ9/sMgAAqBUIKpXAbMoAANQMgkolvPDDdrNLAACgVjA9qBw+fFi33367GjRooDp16qhHjx5KSkoyuyxJ0v9d0trsEgAAqNVMDSrp6ekaMGCAAgMDNX/+fG3fvl1vvPGG6tWrZ2ZZNhfFNzK7BAAAarUAM7/8lVdeUVxcnKZPn25b17JlS/MKKmFAfAOzSwAAoFYztUVlzpw56tOnj2644QZFR0erZ8+e+vjjj53un5ubq8zMTLuf6mSxWJxuG/Lmcm08mO50OwAAqDpTg8revXs1bdo0tW3bVgsXLtTYsWM1YcIEffHFFw73T0xMVGRkpO0nLi6uhis+b3faad328VrTvh8AgNrAYpg4gllQUJD69OmjVatW2dZNmDBB69ev1+rVq0vtn5ubq9zcXNtyZmam4uLilJGRobp161ZLjZe/uVy70k473b5/yohq+V4AAHxVZmamIiMjXfr329QWlSZNmqhTp0526zp27KiDBw863D84OFh169a1+6luXZtGlrmdiQoBAKg+pgaVAQMGaOfOnXbrkpOT1aJFC5MqKu3JkZ3K3L5u34kaqgQAgNrH1KDy4IMPas2aNXr55Ze1e/duffXVV/roo480btw4M8uyExUWVOb2kg/O1u8/qVs/XqNdR7OqsSoAAGoHU4NK37599d133+nrr79Wly5d9MILL2jq1Km67bbbzCyrQkoGlRs+XK1Ve07ojk/XmVMQAAA+xNRxVCRp5MiRGjlypNlllKltdLjTDrXOeiIfycyR1WrIz8/5K84AAKBspg+h7w26x9Vzum1TGWOp5BZYq6EaAABqD4KKCxJaOx+h9oNle7R0Z5rDbWlZOdVVEgAAtQJBxQUjujUpc/td09c7XJ9fWPrB0Nm8QqWfyXNLXQAA+DqCigtCAv0rdZy/g/4pPZ5fpJ4vLNapbMIKAAD
lIai4ydm8wlLrHHWjLeq3si2leucpAgDAFxBU3OT6D1cpt8A+rNz35Uan+wcF8KsHAKA8/GvpJttSMnXLR2vs1u1Idd5q4lfGzMwAAOAcgoobbTx4qsztxed/XMvQ+wAAlIug4qJ//7WfHh3WvkrnKLSeDyoWhz1YAABAcQQVF13UtpH+fml8lc5R/HXl0EB+9QAAlId/LWtQXuH5kWor+8ozAAC1ielz/dQWC347oo9W7LEtE1QAACgfQaWGjP1Pkt1yCI9+AAAoF/9amiQ4gBYVAADKQ1AxCQO+AQBQPv61NEmAg3mAAACAPYJKDfhhS0qpdY4mLAQAAPYIKhX02V19K7T/b4czNP6rTaXWGw72BQAA9ggqFXRp++gK7T91yS6H661WogoAAOUhqFSzJTuOOlxPTgEAoHwElUp4amQn9W1Zv0rnKD5BIQAAcIygUgl/HdhK/x17YZXOQYsKAADlI6iYxEqLCgAA5SKomISgAgBA+QgqJiGnAABQPoJKFfzzjt6VPpYWFQAAykdQqYKhnRvrpWu6VOpYcgoAAOUjqFRRq4ZhlTqOFhUAAMpHUKkiP0vl5uzh9WQAAMpHUKmiygYVBnwDAKB8BJUq8nfwG3RlYmRaVAAAKB9BpYosDlpUZo8bUO5x9FEBAKB8AWYX4O1KPvrZlzjcYXgpiaACAED5aFGpoj1pp+2WXQkpEq8nAwDgCoJKFR3NyqnUcbSoAABQPoKKSehMCwBA+QgqJqFFBQCA8hFUqqh43risQ3QFDnR/LQAA+BqCShW1aRRu+xzgygAqf6JFBQCA8hFUqqh4K0qjiGCXjyskqAAAUC6CShUFBZz/FQ7t3Njl48gpAACUj6DiBl/e019Pjuioi9o2dLj97gGtSq3j0Q8AAOUjqLjBgPiGuuei1k4He5s8vINm3HuB3Tor7ycDAFAugko1aRdzvpOtn8WiJpEhdtvJKQAAlI+gUk1m3Jtg++xnkSyyb23h0Q8AAOVjUsJqEhUWpJWPDlJQgJ8sFotKPhUqL6fkF1qVlpWrpvVCq69IAAA8HC0q1Sguqo5i6oY43FZei8ptn6zVgCk/adWe43br8wqsyiuwuq1GAAA8GUGlhhR/jVmSDqVn68Uftmvf8TMO91+376Qk6au1B23rrFZDCYk/qs+Li1VQSFgBAPg+gkoNiY4I1l8SWtiW/7PmoD75eZ8Gvb6szOOKt7xk5RboxJk8ZeYUaFtKpt5cnKwjGZWbvRkAAG9AH5UaYrFY9PzoLsovtOrrdYdcPq6w+OtBxT7+9fMNOn46Vz/uOKq5Ey5yY6UAAHgOWlRqmLOxVpxx9oTn+OlcSdK2lMyqlgQAgMciqNQwR/MWpmacLeMIo9gnXmkGANQuBJUa5uegReXFH3bYPk/+dqvu+Xx9sa3n9k8/k8cgcQCAWoc+KjXMUVDJzMnXB8t267IO0fp63cFS25MOnNR101arf6uomigRAACPQVDxACt3HdfKXcf16oKdpbZZLNKTs7dJktb++cpySbkFhQoO8K/WGgEAMAOPfmrYD1tSK7T/4u1HtSO17A6zvV9Yopz8wqqUBQCARyKo1LCit3Xc6XRuQblhBgAAb0RQAQAAHoug4iNSGaEWAOCDCCo+Yv5vR8wuAQAAtyOo+Ij//ZpidgkAALhdpYLK559/rrlz59qWH330UdWrV08XXnihDhw44LbiAABA7VapoPLyyy8rNDRUkrR69Wq99957evXVV9WwYUM9+OCDbi0QrjuVnWd2CQAAuFWlgsqhQ4cUHx8vSZo9e7auv/563XvvvUpMTNTKlSvdWiBcd+20VWaXAACAW1UqqISHh+vEiROSpEWLFmnIkCGSpJCQEJ09W9YEe6hOe4+dMbsEAADcqlJD6F9++eW655571LNnTyUnJ2vEiBGSpG3btqlly5burA8AANRilWpRef/995WQkKBjx45p1qxZatCggSQpKSlJt9xyi1sLBAAAtVelWlTq1aun9957r9T65557rsoFAQAAFKlUi8qCBQv0888/25bff/999ejRQ7feeqvS09PdVhwAAKjdKhVUHnnkEWVmnpsEb+vWrXrooYc0fPhw7d27V5MmTXJrgb7my3v6V+v5VyQfq9bzAwBQkyoVVPbt26dOnTpJkmbNmqWRI0fq5Zdf1gcffKD58+e7tUBf07tF/Wo9/1/+ta5azw8AQE2qVFAJCgpSdna2JGnJkiW64oorJElRUVG2lpaKSkxMlMVi0QMPPFCp471FSKC/2SUAAOA1KtWZduDAgZo0aZIGDBigdevWaebMmZKk5ORkNWvWrMLnW79+vT766CN169atMuUAAAAfVakWlffee08BAQH65ptvNG3aNDVt2lSSNH/+fA0bNqxC5zp9+rRuu+02ffzxx6pfv3ofiwAAAO9SqRaV5s2b64cffii1/q233qrwucaNG6cRI0ZoyJAhevHFF8vcNzc3V7m5ubblyj5mAgAA3qFSQUWSCgsLNXv2bO3YsUMWi0UdO3bU6NGj5e/veh+MGTNmaOPGjVq/fr1L+ycmJjJWCwAAtUilgsru3bs1fPhwHT58WO3bt5dhGEpOTlZcXJzmzp2rNm3alHuOQ4cOaeLEiVq0aJFCQkJc+t7Jkyfbvf6cmZmpuLi4ylwCAADwApXqozJhwgS1adNGhw4d0saNG7Vp0yYdPHhQrVq10oQJE1w6R1JSktLS0tS7d28FBAQoICBAy5cv1zvvvKOAgAAVFhaWOiY4OFh169a1+/EVDw5pZ3YJAAB4nEq1qCxfvlxr1qxRVFSUbV2DBg00ZcoUDRgwwKVzDB48WFu3brVbd9ddd6lDhw567LHHKvQIyRcYMswuAQAAj1OpoBIcHKysrKxS60+fPq2goCCXzhEREaEuXbrYrQsLC1ODBg1Kra8NLLKYXQIAAB6nUo9+Ro4cqXvvvVdr166VYRgyDENr1qzR2LFjddVVV7m7xlrBQk4BAKCUSrWovPPOOxozZowSEhIUGBgoScrPz9fo0aM1derUShezbNmySh/r7cgpAACUVqmgUq9ePX3//ffavXu3duzYIcMw1KlTJ8XHx7u7Pp/WJDJEz4zqrF7N62nm+kNmlwMAgMdxOaiUNyty8daQN998s9IF1SYf/6WPujSNlCSFBtWuzsMAALjC5aCyadMml/az0NmiXBufulzHsnLVvnGEbV2gf6W6CwEA4NNcDipLly6tzjpqlaiwIEWF2b8dFRkaaFI1AAB4rkoPoQ/3GtmtiVbuOq7+raK0bv9JfZP0h9klAQBgOoKKhwjw99MbN3aXJK3ff9LkagAA8Ax0jPBAYy5saXYJAAB4BIKKB+rSNFKbn77c4bZxg8qf8BEAAF9BUPFQ9eqUnorg4naN9MjQDi4df8ena9Xy8bk6nVvg7tIAAKgxBBUvcnHbhi7tZxiGVu46Lkl6evZvVf7e9DN5SsvKqfJ5AACoKIKKF3F1jBprsYmYdx4tPXlkucdbDZ3NK7R97vnCYvV76UfbOgAAagpBxYNd1T3WbvlUdp5LxxUWSyrFQ4urxkxfp4
5PL1BaZo4Kip3gSCatKgCAmkVQ8WBv39xDO54fZlueuyVVkjTj3gs0ML6hggIc3z6rcT5cFBRa7bbtSM1URna+kg6k67LXl2nZzrRSxxc9Npq+ar/duQAAqGmMo+LBLBaLQgLPh5HCP0PDBa0b6ILWDWQYhlpNnlfquOItKrvSTsv487htKZka+e7PahIZopNn8pRbYNWd09dr/5QRks71bRn/9fmpEpb+nqb7L2OiSQCAeQgqHq54v5SSrRvO+qwUltjv0teX6cCJbNtyaobjRziHTp61tdpI5wJRYWWeHQEA4CYEFS9itZa/jyQZJfYrHlLKPH+JgPPZqv36bNV+174UAIBqQB8VL+Jqf5GSLSrlOXTyXJDx9yv7rSKD/ioAgBpGUPEiBS4+hnno/22u0HkvenWpft51XOW9/cxTIABATePRjxfxLydJ+FnOhYmlO49V+NxfrzuoVg3DytyHFhUAQE0jqHiRr/7Wv9S6QH+L8gvPBYiqtHicPJOnuVtTy9ynoo+UAACoKoKKF9g/ZYRyCwoVHOBfaptFFklVDxCr954odx/eAAIA1DT6qHgJRyGlprn61hEAAO5CUPF2rk3/4xY8+gEA1DSCiperwZyiFckV76QLAEBVEFS83MQhbSVJN/RuVu3fxeBvAICaRmdaL3ffJW00uEOM4qPD9d+kP6r1u5igEABQ0wgqXs5isah944ga+S7e+gEA1DQe/cBlVoIKAKCGEVTgsnyCCgCghhFU4LK8AgZSAQDULIIKAADwWAQVuGxMQguzSwAA1DIEFbisZTmzKwMA4G4EFbiMvrQAgJpGUIHLqvp6ckGhVadzC9xUDQCgNiCowGVVHZn2yrdXqsszC3XyTJ6bKgIA+DqCig8Z0bVJtZ6/qrMn70o7LUn6efdxd5QDAKgFCCo+ZEB8w2o9v7um+vGrySmfAQBejaDiQyzVHADcNdePRSQVAIBrCCo+pLpbKtw1e3J1ByoAgO8gqPgQSzUnACYlBADUNIKKD+vTor5bz1fVzrRF6KMCAHAVQcWHGCWCRGy90Eqfa8fzw/TkiI5266rSoGJfG0kFAOAagooP6dI00vb55r5xSs92PF5Jj7h6+uC2XmWeKzTIX51i69qtq8qjn+KH0kcFAOCqALMLgPt0jo3UV/f0V0xkiNo0ClfLx+c63K9jkwi1iwkv93xB/vY5tiqdaQusVttncgoAwFW0qPiYC+Mbqk2jcyGkX8sou21zxg/QXwe20uThHdU8qvwJBns2t+/jUmh1sqMDhmEo+WiWcgsKlXQgXTn5xYIKTSoAABfRouLD2kSHad3+k7blbs3qqVuzepKkfBdSh7+fRRe2aaBVe05Ikn7efczl7/7P2oN6avZvtuUmkSG2z8QUAICraFHxYZ2a1HW6za+MVo1ezevZPk+9qYftc/LR00rNOOvSd3+4bI/dcmpGju3zPV9s0G+HM1RQaNV7P+1S0oF0l84JAKh9CCo+7JZ+zZ1uK6tVo3iIia4bYrctIfEnZeXkl/vdJd9AKunaaav09bqDen1Rsq6btqrc8wEAaieCig8L8Hd+e6vSTWT+b0dKrTuSkaPdaVm25fK63eYVWPXU99sqXwQAoFYgqNRSZXVoLS9kLE8u3VflgsQfNeTNFTqaee4Rj7smMAQA1G4EFR+3ZNLFGtU9VksfvrTUtpev6arggIr/EWhWYiC5s3mFts+7005LYqwUAIB7EFR8XHx0hN69padaNSz9OvKt/Ztr/sSLJEndmp0fLK5k/5Jb+9v3dbmkfSO75b9+vt722c9i0bNzttl1ngUAoLIIKrVc60bh2vjU5fr2vgtt60o+tSn59pB/ieaSoteXJWnFrmP6bNV+d5cJAKilGEcFigoLKnP7wZPZdstljaQ/rcRryQAAVAUtKiilcYlXkuf/lmq3XJWh9AEAqAiCCmw+v7ufLu8Uo+dGdy5zv+XJx2ydZgEAqE4EFdhc0q6RPv5LH0VHhJS530cr9mrIm8tty2FB/tVdGgCgliKooMrOFHs9GQAAd6IzLcpldTJ/4Ts/7lJ8dHjNFgMAqFUIKijXmbwCh+vfXJxcw5UAAGobHv2gXC9e3cXsEgAAtRRBBeVqHxNhdgkAgFqKoIJy+ft5z8Q9S39P06T/t1mncx0/rgIAeBf6qKBcAX7ek2fv+uzcvEONIoI1+cqOJlcDAKgq7/kXCKbx9/eeFpUiqaeYFBEAfAFBBeUK9KJHP0XO8OgHAHwCQQXlyjibb3YJFfbj72lmlwAAcAOCCsoVEsgQ+QAAcxBUUK6YumXP/eMOhptnZG7dMMyt5wMAmIOggnIFuLGPyrW9mjpcb3VvTlGP5vXce0IAgClMDSqJiYnq27evIiIiFB0drauvvlo7d+40syQ44FfFoDJ73IDz57I4PpfVzS0qFnlfB2AAQGmmBpXly5dr3LhxWrNmjRYvXqyCggJdccUVOnPmjJllwc16xNUrd5+KBpVfdh/XgRP8OQEAX2fqgG8LFiywW54+fbqio6OVlJSkiy++2KSqUJ0MQ7qmZ1N9t+lwqfWu+G7TH3pw5q+25f1TRjjcjxADAL7Bo/qoZGRkSJKioqIcbs/NzVVmZqbdDzzHnPEDSq27vFNMqXUD4huWWudqi0rxkFKWDQfStWT7UeUVWF3aHwDgmTwmqBiGoUmTJmngwIHq0sXxbL2JiYmKjIy0/cTFxdVwlbXX1T1iy9x+94BW6tasnj67q6/d+ndv6Wm3bLGcO9dNfeKUeG1X23p3d6aVpHu+2KCX5m53/4kBADXGY4LK+PHjtWXLFn399ddO95k8ebIyMjJsP4cOHarBCmu3qTf3LHP7db3Pvc1zaftofXBbL0nSiG5NSo3BYhhSgL+fXrm+m67r1cy23t2daYt8vvpAtZwXAFAzPGJSwvvvv19z5szRihUr1KxZM6f7BQcHKzg4uAYrQ3F3XNBC/15T/j/8w7s20erJlykmouzxV4q/TGRYz7WqbU/NVMsGYQoL9og/mgAAk5naomIYhsaPH69vv/1WP/30k1q1amVmOSjHUyM76au/9Xdp3yaRoQ5fazZ0vuWk+KvKhYahJTvSNOKdnzXq3Z9drqmg0Kq0rHMTEB46me3ycQAA72Dqf7aOGzdOX331lb7//ntFREToyJEjkqTIyEiFhoaaWRocCArw04VtSneElSo3bknxIGM1DH278Q9J0t7j59/YOZaVq6AAP0WGBjo8R+dnFiq3wKqwIH+dySuscA0AAM9maovKtGnTlJGRoUsvvVRNmjSx/cycOdPMslAJfpX8k1TUqPLP5XuUk28fNE7nFqjvS0vU/blFTo/P/fOtHmchpU0jhtIHAG9maouKu+d3Qc0YPyhe/9uSoo/u6KOhU1dIkuqGOG7xKKleaJDdsp/FokLD0Mcr99mt33PstAa/sdy2XGg1tLQSMyI7GwkXAOAd6LGICnt4aHs9PLS9JOnV67rpbH6hYuuV/ajurZu6a1bSYU0YHG+33lmMKB5SJKnAatU9X2yocK0EFQDwbgQVVMmNfV0by+aans10Tc/Sb
3QVuDiAyvp96RWqqwg5BQC8m8eMowKU5fZP11bqOFpUAMC7EVTg0yrbyRcA4Bn4axw+jRYVAPBuBBX4NAtBBQC8GkEFPs3B4LiVZhiGzuQWuO+EAIByEVTg09zZnvLMnG3q/MxCbdh/0o1nBQCUhaACn+bOIQW/+HMm5reWJLvxrACAshBUgAqqzLxGAIDKIajAp1XHLA0/7z7u/pMCABwiqAAAAI9FUIHPuf+y8/MJhQb6u/38LRvUcfs5AQCOEVTgc9rGRKhJZIgk2f7XnaLruv+cAADHCCrwOSO7NtHdA1pJkqwV6KTy/tLdmvNrim35TG6B/r16v45m5tjtR1daAKg5BBX4HD8/i23WZFdjytY/MvTawp2a8PUm27pn5mzTU99v080frXF/kQAAlxBU4FOm39VX0vk5fqwuJpVTZ/NKrfsm6Q9J0r7jZ5RXYHVPgQCACiGowOP1axXl8r6D2kdLkq1FxdVHPwElplnOL7QPJu2enO9yDQAA9yGowOO1ahBmtxwZGmj3Zs+NfZpJki5u18i2zq+Cz34C/M/3PPlw+R69++OuSlYLAHCnALMLAMpzQ59mmrnhkG35wSFtFRUebFvu2by+/jG8kyJCzv9x9qtgi0rxDrJT5v9epXoBAO5DUIFHe/zKDmrZ8HyLyqz7EtSreX2lZpx/E6dvyyhF1gm0P9DWR8XFoMKrPADgkXj0A4829pI2Cgs6n6c7NYmUxWLR6dwC27rI0MBSxxW1qBiGlHQgXf/6eZ+MMkLLsazSnWkBAOajRQWm+mZsgjYcSNcNvZup94tLHO4TGuSv/45NkGGc+yxJQf7nM3bxRz5FTpw+FzwWbT+qRduPSpLO5hdq3KD4UvtK0tv0SQEAj0SLCkzVp2WUxl7SRg2K9Tkp0q/l+bd9+raMsnv7x74/SunnNu84CB6vLdwpSVq794QmzdysE6dzbdt2pGa6XPPafSdd3hcAUDW0qMBjxZQx/H2D8GC9eWN3BQX4KSigdN4ucDKAyrCpK/T7kSxJ0pq9J7Rq8uBK1ZaTX6iQaphHCABgjxYVeIzGJebQCfQvu4frtb2aaWS32Ap9R1FIkaSUjBzNXH9Qx7JyyzjCMUarBYCaQVCBx1jx6CC75YeuaF/pc7VwcYbjx2ZtVd+XHPeNKcvmQ6cqfAwAoOIIKvAYxR/hLJl0sZrWC630uSZd3s4dJQEATEYfFXiUuRMGKi0zV/HREVU6zw9bUt1UEQDATLSowKN0jo3UoA7RVT7P4j9fSfYm/91wSI99s0WFrs6kCAC1AEEFPql1w7Dyd/Iwj3yzRTM3HNIPW1LMLgUAPAZBBT7pgjYNzC6h0o6fZpRcAChCUIFPGtmtSbV/x+FTZ6t0fFpWjqxWQ5+v2q9vkv6wrS+0WqtaGgD4DDrTwie1qoFHPzn5hZU+dnnyMY351zr1axWldYx0CwBO0aICn1Ry8LjqcLTYDM4V9f7S3ZJESAGAchBU4JMsDub/cbdbP1mrRduOVPi4jQfTywwoy3Yeq0pZAOBTCCpAGQbEl90p995/Jzlc//W6g/p+8+FS66ct26NrP1hV5jlX7TnheoEA4OMIKvBZM+69oFLHXdEpRm/f3EO3X9BcU2/qaVv/yFDXhvQ/lpWryd9u1cQZm5VfeL5jbG5BoV5Z8Hu5xzcpYzJGAKhtCCqo9fa8PNz2OSosSB/9pY9G92iqF6/uqkYRwbZtV3V3PAGitcQAbXnFwsnZPzvcrt5zQu2fXOBSPaN7NHW5dgDwdQQV+CzDxQFe/f0sGhjfUJJ0c9+4Utv/OzZBH97eW3FRdRy2qiwqMQpugN/5/jE5eeeCyi0fuz7bckQIL+MBQBH+RkSt9OyoTqofFqSecfUlSf+8o7eSDqQrwcFAcX1bRtk+jxsUrw+X7VFWboFt3YETZ+z2zys436LyyoKdmrXxD1VE8cdFAFDbEVTgswL8nb/5c+eAVnbLYcEBurhdI5fOWzykSFJO/vlgUVBo1UWvLrUtVzSkSPZBBwBqOx79wGf1jKunqLCgav+es/mF2nU0Szd+uFrfbSr9pk9F0aICAOfRogKfFeDvp41PXa6Ms/kKDw6Qv59F3278Q22jI9z6PR8u36MPl++RJK3bX/UB3PIL3Tt7cmZOvvaknVaPuHo1Mr4MALgTQQU+LzI00Pb52l7NTKzENbluevRjtRpq/cQ82/KTIzrqnotau+XcAFBTePQDeBh3Pfp5Y/FOu+UX5+5wy3kBoCYRVAAP467OtO8v3eOW8wCAmQgqgIdxR4uK4eogMgDg4eijAniYirSobEvJ0FuLdynAz6K2MeG6I6GF6tcJUtt/zK/GCgGg5hBUgGp214CW+ibpD2XlFJS5n59Fshr2Q/A7smH/SS3ZkaaHr2inEe/8bFu/YJv07k+7nR7Xu0X9ihUOAB6ARz9ANXtmVGdtfXao/jd+YJn7PXtVZ0nnH/0s+O2IWj4+V68v3Klvks4PHHf9h6v14fI9GvLm8grVkXQgvYKVA4D5CCpADenaLFIzS8zoPKj9udFw1z4xWDF1z82anJaVK0ka+58kSdJ7S3fr4f/+qpW7jtkdu/9EdoVrKGAwOQBehkc/QAU9dHk7vbE42aV9vxmbYLfcv3UDLXjgIh3PylNQgJ/6tTo/j1BR/9e9x86o0Fq6M+zHK/fpjk/XVb5wScdP56lxZEiVzgEANYkWFaCC7h/cVvsSh5daPyDefkLD7s0i1afYhIZFOjSuq4FtG9qFFEnKysm3fR7xzspSx61IPlZqXUVZK/A20PLkY7ro1Z+0es8Jvbk4WYNeXyargwAFANWJoAJUgsVi0dSbetiW/5LQQl/ec4F2vXSlpt7UQy0b1NGr13ev0DlDAv1tn38/kuWuUu0cPnXWpf0OnsjWmH+t06GTZ3X7p2v1zo+7tO/4GY2ZXrUWHQCoKIIKUElX92xq+1zvz2H6A/39dHXPplr2yCC1b1yxOYWCA6r//4716wSWv5Oki187PwN08cdQK3cdlySt3nNCLR+fqxHvrNTZvEL3FgkAxRBUgCqYdHk7tY+J0F8HVn0OnZp4qnI2z74z7dQlyXp/qf0rzSdO55Z5ji9W79ctH6+RJG1LydTX6w66pTbDMAg9AEohqABVMGFwWy188GJFuthSUZb46DA3VCS1jQ63W25cN8TWknI2/3wQyMjO19Qlu/Tawp06lZ1nW9/7xSVlnv/p77fZLT//w3bl5Fc9YLSaPE8dn16gQycr/jYTAN9FUAE8RHx0xR4VObPowYv18jVdNeXartqXOFxrnhis2HqhkuyDSvGOta8t3FnqPBXR4akFVTq+uGs++MVt5wLg/Xg9GfAhTeuFymKx6Nb+ze3Wh/7ZUbf4o5WX5p2fTfnLtQc1Z3OKsnLLHj3X3X7edVy3f7pWd1zQwrYut8CqvAKrgmqgzw4Az0dQAXzEfZe20diL2zjcFhp0LqhMmLHJ6VxCNR1SJOn2T9dKkv695sD5OnIK1O7J+dr+/FDVCeKvKKC2428BwIMEBfi5PCnhJ3/po06xdbXlj1NqHhWmTrF1
ne57JCNHUsUmPHQnwzBksVjs1h0vp9Nup6cXSpJ2PD/MFrSq6vvNh1U3NFCD2ke75XwAqh9tq4AH+d/4gbq1f3MN6RgjSfq/S1pr/5QRpfZrGB6kIZ1iFFsvVMO6NCkzpEjSrrTT1VJveYrmK2o1eZ5mbzqslo/PVcvH52rvsdPqU06n3SLzf0utUg3JR7NktRo6fOqsJs7YrLumr5dRgYHvAJiLFhXAg7RvHKGXr+mq/EKrtqdkqkvTSEnSsocv1aWvL5MkDe0co7dv7mlilaUFBfgpIztfO49mqXHdEG04cFL5hVY9NmurbZ8HZm62fb7sDdcnVIwKC7JbTj6apbbR4aVaaIozDEOFVkOPztqibzceliTNm3CRbXvm2QK3vKkFoPpZDC/+T4vMzExFRkYqIyNDdeuW/V+UQG3W8vG5lTpuX+JwtZo8T5I0f+JFys4r1HXTVpXar16dQJ3Kzi+13h2m3tTDNrhe8esoamk6dPLcKLoTh7TV6B5NZbUaav3EvFLn8bPYj1Uz674E9W5ReoqD8lithv695oB6Na+vrs0iK3y8Jxnzr3UK9LfokzF9zS4FtUxF/v0mqAC1QPwT81TgYES5fq2itG7fSYfHFAWB7LwCZecVqmF4sN32v362Xj/+nub+Yh1IenKI6oYGqu0/5tvW7UscLovFUukQJp27xkKroc2H0tU5NtJuGoOSTpzOVXp2ng6dPKu7PltvO16SXl3wuz5YtkdPj+ykuwe2qnQ9xRVdV9KTQ9SgxO++qpYnH1OAn0W3fXKuM/OCBy5Sh8b8HYqaQ1ABYGfVnuO69eO1duvuGtBSz4zqrLlbUjXuq42ljnHUN6Y4q9XQ377YUGNhpTrsnzJC7/20S68vStboHrF2j9T2HDutNXtP6Oa+5171buOglWbBAxdp2FT7CSTL+72VlHE2X5GhpR9DlQxg6/8xRI0iHAeWoski/fycPw7LyS9UcICfbvznaq3fn2637dpeTfXmjT0qVDdQFQQVAKXc/NFqrdl7rvVk9rgB6hFXT9K5/hyLtx9VfHS4re/Ia9d30w194so9Z16BVe2enF/ufmW5e0ArdY6tq9+PZOqJ4R1lsVhUUGhV/D+qdl5X7J8yQu2fnK/cYm9D7Xl5uPz9qtZSs2TSxeUO4Hc0M0fr95/U+K826dqeTfXtpsO69+LWemxYB3259kCpEYCL6i0pO69AnZ5eqNBAf+14YVip7WdyC7Rs5zGHYbS8c5eUfiZP+VaroiNCnO7j6A2vysrJL9S2lAz1iKsv/z9DWGrGWS3eflTX925WodfXDcNQVm6B6oY475t06GS2GkUEl9myBvcgqABwaHdaltKz89W3peO+GTn5hTp0MlttY1wfJdeVf9Bfua6rXcdaSWoeVUcPDGmra3s1c3hMyQBRHV64uouemv2b2887IL6BvrznAq1IPqYVycd0/+C2+mFLikZ2i1VkaKCWJx/TmH9VfCbqrc9eIatVCgywaHtKpnq3qK8OTy2w/Z6CA/yUW2BVr+b1tPHgKT13VWc9M6d04HFk7oSB6hxbus/N2bzCc6+1W6Tuzy2SJE2/q68GtY9Wxtl8/bzruIZ2jlF6dr5emrtdy5KPqW/LKI29pLWum7Za/VpFaea9F1QqvPztiw1avP2oHhnaXpd1iNbJM3m2x1Vto8O1eNIlTo/dsP+k6oYGqt2ff5aL/pzeeWFLPXtVZ0nSsaxcvffTLo2/rK3W7z+pv395LsxNv7OvBnXwrFfY08/kqcBqOG1VK4thGNqemqk2jcKdhrC/f5mkeVuPaO0Tg1WvTqD8LRalnMpRs/qhZbbUVRZBBUCNcSWoVPRxiCR1fXahsnKqNgjdpe0badnOY1U6R2U8fmUHdY6tqzs+LR1GQgL9lJPvngAWGuhvNy2Cu2x4cogahgeXGRYrch1PDO+ge0sMRvi/X1N0/9eb9M4tPdW9WaQue2O55k+8SPmFVk3/Zb9u7BOnG/+5uszzluwQfeDEGZ04k6eYuiEaMOWnP/e5UL2a17N1Ci9ya//m+mqt8wk1f3tuqMKDA7Tv+BmlZpzVzPWH9MjQ9mpWv45L11zSsKkrlFdg1U8PX1rhY4u3MG586nLbm3BWq6E1+04orn4dxUWVrutUdp4iQwP1/eYUPTBzs4Z2jtE/7+hj2/5HerYMQ4qLquP0/8fFg507EVQA1KhhU1fo9yNZ2vz05bbm+EB/i347nKl2jcMVHFDxpvQVycf0l2KtDksmXawhb64o85jP7+6npAPpeufHXdr10pUK9PfTv37ep+d/2K4b+zTTQ1e01/tLd+uL1QfKPA/cL/nFK/XWkmTN3nRYqX8OQOgO258fqo9W7FVYUIDdtBDu8OyoTnr2f9vt1n379wt17QerZLFI+xLPBXDDMLRs5zFbJ+uyvH5Dd13751tsfn4WffrzPlmkUp2wiz9CS804q4TEc8HrgSFtdXPf5vKzSP1e/tHumHsvbq0nhneUJI1+/xf9euiUEq/tqsnf2rdmStLKRwfpoleXuvBbOBf4P7urn0v7usqrgsoHH3yg1157TampqercubOmTp2qiy66qPwDRVABfJ1hGPpl9wm1bxxha/LenXZa9/57g54d1VmNIoJ1+ydr9d3fB6h5A9f/S7cq/U+kc28cHcnM0cz1h3Rz3+YKC/bXRyv26t2fdlfpvCWN7NZEP2yp2oB3jozuEavvN6dU26MvuMeHt/fS2P+U3beopJYN6mj/CffPQF6ZVtGyeE1QmTlzpu644w598MEHGjBggP75z3/qk08+0fbt29W8efNyjyeoAKistXtP6KaP1tiW7xrQUtN/2a8f7h+oZvVD9eDMzbq5X3NFhASobXSEGoYHac6vKWrTKNw2EF9J3236Qw/O/LXCtVzRKUbv39bL9vp14rVdFR0RrMs6RMtisSgtM6fUfz1XxNonBiumbogMw9Cp7HzVLzaIXtFbT0BZam1Q6d+/v3r16qVp06bZ1nXs2FFXX321EhMTyz2eoAKgKv5Iz9ZvhzM1tHOM295UOZaVqx+2pGhQ+2hl5uTrSEaOkg6ka2S3WH276Q8dOpmtG/vEKdDfT2fzC7Xv+Bndd0kbW4dFZ2/NTJn/uz5cvsfhd+5+6UpbH4Z1TwzW2n0n1Tyqjg6fOquE1g3sgklJGdn5uvWTNdqWkulw+/CujdUkMlQtG9TR0cxcDerQSF2aRurp2du0LTVDn47pq00H03VhfEP9c/kedW0aqWFdmqig0KrMnAL1emFxhX5/F7SO0ts399SK5GP6fnOKJl3RTpNnbdXOo1l6bFgHLdh2RBMHx+vuzzZU6LxF2kaHO5xSYtygNvrbRa2Vk2/VBYmVD4W+qLKDI5bFK4JKXl6e6tSpo//+97+65pprbOsnTpyozZs3a/ny0kNs5+bmKjf3/ERmmZmZiouLI6gAqDVy8gsV6O8nP8u5gdu6No10y4BwGdn5Wr33uIZ1aeKGKs87kpGjH7ak6MquTbRq93FtT81UTN0Qjb3kfOfayr7SfCQjR098t1U//Z6m12/orut6NdWJM3m
6+aM1eunqLlqy46i+SfpDG5+6vNT5DcPQ5kOnZLFYbK/qF7dq93F9ufagBrZtqKiwILWPidAPW1IUEuivC9s01LytqXpv6W71blFfg9o30tfrDunZqzqrSWSIIkMD1ax+qCwWi921Zebkq9uziyp8nZI0uEO0LBZpyY40LX/kUmXlFOjRb7ZozvgBkqTvNh3WI99scXhst2aRuqxDtBJaN9CJM3m6sktjFVoNZecXKsjfT7n5VtuUEkt3pslqNTS4Y4xOnsnTqew8tW4UXqmay+IVQSUlJUVNmzbVL7/8ogsvvNC2/uWXX9bnn3+unTt3ljrm2Wef1XPPPVdqPUEFAADvUZGgYvrsyY5SrrNkPXnyZGVkZNh+Dh06VBMlAgAAk5g2e3LDhg3l7++vI0eO2K1PS0tTTEyMw2OCg4MVHOzeOS8AAIDnMq1FJSgoSL1799bixfYdrRYvXmz3KAgAANReprWoSNKkSZN0xx13qE+fPkpISNBHH32kgwcPauzYsWaWBQAAPISpQeWmm27SiRMn9Pzzzys1NVVdunTRvHnz1KJFCzPLAgAAHsL0kWmrgnFUAADwPl711g8AAIAzBBUAAOCxCCoAAMBjEVQAAIDHIqgAAACPRVABAAAei6ACAAA8FkEFAAB4LFNHpq2qorHqMjMzTa4EAAC4qujfbVfGnPXqoJKVlSVJiouLM7kSAABQUVlZWYqMjCxzH68eQt9qtSolJUURERGyWCxuPXdmZqbi4uJ06NAhnxyen+vzfr5+jVyfd/P165N8/xqr8/oMw1BWVpZiY2Pl51d2LxSvblHx8/NTs2bNqvU76tat65N/AItwfd7P16+R6/Nuvn59ku9fY3VdX3ktKUXoTAsAADwWQQUAAHgsgooTwcHBeuaZZxQcHGx2KdWC6/N+vn6NXJ938/Xrk3z/Gj3l+ry6My0AAPBttKgAAACPRVABAAAei6ACAAA8FkEFAAB4LIKKAx988IFatWqlkJAQ9e7dWytXrjS7pFKeffZZWSwWu5/GjRvbthuGoWeffVaxsbEKDQ3VpZdeqm3bttmdIzc3V/fff78aNmyosLAwXXXVVfrjjz/s9klPT9cdd9yhyMhIRUZG6o477tCpU6eq5ZpWrFihUaNGKTY2VhaLRbNnz7bbXpPXdPDgQY0aNUphYWFq2LChJkyYoLy8vGq9vjvvvLPUPb3gggu85voSExPVt29fRUREKDo6WldffbV27txpt48330NXrs+b7+G0adPUrVs32+BeCQkJmj9/vm27N987V6/Rm+9fSYmJibJYLHrggQds67z2HhqwM2PGDCMwMND4+OOPje3btxsTJ040wsLCjAMHDphdmp1nnnnG6Ny5s5Gammr7SUtLs22fMmWKERERYcyaNcvYunWrcdNNNxlNmjQxMjMzbfuMHTvWaNq0qbF48WJj48aNxqBBg4zu3bsbBQUFtn2GDRtmdOnSxVi1apWxatUqo0uXLsbIkSOr5ZrmzZtn/OMf/zBmzZplSDK+++47u+01dU0FBQVGly5djEGDBhkbN240Fi9ebMTGxhrjx4+v1usbM2aMMWzYMLt7euLECbt9PPn6hg4dakyfPt347bffjM2bNxsjRowwmjdvbpw+fdq2jzffQ1euz5vv4Zw5c4y5c+caO3fuNHbu3Gk88cQTRmBgoPHbb78ZhuHd987Va/Tm+1fcunXrjJYtWxrdunUzJk6caFvvrfeQoFJCv379jLFjx9qt69Chg/H444+bVJFjzzzzjNG9e3eH26xWq9G4cWNjypQptnU5OTlGZGSk8eGHHxqGYRinTp0yAgMDjRkzZtj2OXz4sOHn52csWLDAMAzD2L59uyHJWLNmjW2f1atXG5KM33//vRqu6ryS/5DX5DXNmzfP8PPzMw4fPmzb5+uvvzaCg4ONjIyMark+wzj3l+To0aOdHuNN12cYhpGWlmZIMpYvX24Yhu/dw5LXZxi+dw/r169vfPLJJz537xxdo2H4xv3Lysoy2rZtayxevNi45JJLbEHFm+8hj36KycvLU1JSkq644gq79VdccYVWrVplUlXO7dq1S7GxsWrVqpVuvvlm7d27V5K0b98+HTlyxO46goODdckll9iuIykpSfn5+Xb7xMbGqkuXLrZ9Vq9ercjISPXv39+2zwUXXKDIyMga/33U5DWtXr1aXbp0UWxsrG2foUOHKjc3V0lJSdV6ncuWLVN0dLTatWunv/3tb0pLS7Nt87bry8jIkCRFRUVJ8r17WPL6ivjCPSwsLNSMGTN05swZJSQk+Ny9c3SNRbz9/o0bN04jRozQkCFD7NZ78z306kkJ3e348eMqLCxUTEyM3fqYmBgdOXLEpKoc69+/v7744gu1a9dOR48e1YsvvqgLL7xQ27Zts9Xq6DoOHDggSTpy5IiCgoJUv379UvsUHX/kyBFFR0eX+u7o6Oga/33U5DUdOXKk1PfUr19fQUFB1XrdV155pW644Qa1aNFC+/bt01NPPaXLLrtMSUlJCg4O9qrrMwxDkyZN0sCBA9WlSxfb9xbVW7J+b7uHjq5P8v57uHXrViUkJCgnJ0fh4eH67rvv1KlTJ9s/QL5w75xdo+T992/GjBnauHGj1q9fX2qbN///j6DigMVisVs2DKPUOrNdeeWVts9du3ZVQkKC2rRpo88//9zW+asy11FyH0f7m/n7qKlrMuO6b7rpJtvnLl26qE+fPmrRooXmzp2ra6+91ulxnnh948eP15YtW/Tzzz+X2uYL99DZ9Xn7PWzfvr02b96sU6dOadasWRozZoyWL1/u9Du98d45u8ZOnTp59f07dOiQJk6cqEWLFikkJMTpft54D3n0U0zDhg3l7+9fKvGlpaWVSoeeJiwsTF27dtWuXbtsb/+UdR2NGzdWXl6e0tPTy9zn6NGjpb7r2LFjNf77qMlraty4canvSU9PV35+fo1ed5MmTdSiRQvt2rXLVpc3XN/999+vOXPmaOnSpWrWrJltva/cQ2fX54i33cOgoCDFx8erT58+SkxMVPfu3fX222/7zL0r6xod8ab7l5SUpLS0NPXu3VsBAQEKCAjQ8uXL9c477yggIMB2Xq+8hxXu1eLj+vXrZ9x333126zp27OhxnWlLysnJMZo2bWo899xztk5Tr7zyim17bm6uw05TM2fOtO2TkpLisNPU2rVrbfusWbPG1M60NXFNRR3BUlJSbPvMmDGj2jvTlnT8+HEjODjY+Pzzz73i+qxWqzFu3DgjNjbWSE5Odrjdm+9hedfniLfdw5Iuu+wyY8yYMV5/71y5Rke86f5lZmYaW7dutfvp06ePcfvttxtbt2716ntIUCmh6PXkTz/91Ni+fbvxwAMPGGFhYcb+/fvNLs3OQw89ZCxbtszYu3evsWbNGmPkyJFGRESErc4pU6YYkZGRxrfffmts3brVuOWWWxy+htasWTNjyZIlxsaNG43LLrvM4Wto3bp1M1avXm2sXr3a6Nq1a7W9npyVlWVs2rTJ2LRpkyHJePPNN41NmzbZXg2vqWsqerVu8ODBxsaNG40lS5
YYzZo1q/Krg2VdX1ZWlvHQQw8Zq1atMvbt22csXbrUSEhIMJo2beo113ffffcZkZGRxrJly+xe78zOzrbt4833sLzr8/Z7OHnyZGPFihXGvn37jC1bthhPPPGE4efnZyxatMgwDO++d65co7ffP0eKv/VjGN57DwkqDrz//vtGixYtjKCgIKNXr152rx96iqL33wMDA43Y2Fjj2muvNbZt22bbbrVajWeeecZo3LixERwcbFx88cXG1q1b7c5x9uxZY/z48UZUVJQRGhpqjBw50jh48KDdPidOnDBuu+02IyIiwoiIiDBuu+02Iz09vVquaenSpYakUj9F/7VTk9d04MABY8SIEUZoaKgRFRVljB8/3sjJyam268vOzjauuOIKo1GjRkZgYKDRvHlzY8yYMaVq9+Trc3Rtkozp06fb9vHme1je9Xn7Pbz77rttf+81atTIGDx4sC2kGIZ33ztXrtHb758jJYOKt95Di2EYRsUfGAEAAFQ/OtMCAACPRVABAAAei6ACAAA8FkEFAAB4LIIKAADwWAQVAADgsQgqAADAYxFUAACAxyKoAKiSli1baurUqS7vv2zZMlksFp06daraagLgOxiZFqhlLr30UvXo0aNC4aIsx44dU1hYmOrUqePS/nl5eTp58qRiYmIqPaV9VS1btkyDBg1Senq66tWrZ0oNAFwTYHYBADyPYRgqLCxUQED5f0U0atSoQucOCgpS48aNK1sagFqGRz9ALXLnnXdq+fLlevvtt2WxWGSxWLR//37b45iFCxeqT58+Cg4O1sqVK7Vnzx6NHj1aMTExCg8PV9++fbVkyRK7c5Z89GOxWPTJJ5/ommuuUZ06ddS2bVvNmTPHtr3ko5/PPvtM9erV08KFC9WxY0eFh4dr2LBhSk1NtR1TUFCgCRMmqF69emrQoIEee+wxjRkzRldffbXTaz1w4IBGjRql+vXrKywsTJ07d9a8efO0f/9+DRo0SJJUv359WSwW3XnnnZLOBbRXX31VrVu3VmhoqLp3765vvvmmVO1z585V9+7dFRISov79+2vr1q2VvCMAykNQAWqRt99+WwkJCfrb3/6m1NRUpaamKi4uzrb90UcfVWJionbs2KFu3brp9OnTGj58uJYsWaJNmzZp6NChGjVqlA4ePFjm9zz33HO68cYbtWXLFg0fPly33XabTp486XT/7Oxsvf766/r3v/+tFStW6ODBg3r44Ydt21955RV9+eWXmj59un755RdlZmZq9uzZZdYwbtw45ebmasWKFdq6dateeeUVhYeHKy4uTrNmzZIk7dy5U6mpqXr77bclSU8++aSmT5+uadOmadu2bXrwwQd1++23a/ny5XbnfuSRR/T6669r/fr1io6O1lVXXaX8/Pwy6wFQSZWacxmA1yo59bthGMbSpUsNScbs2bPLPb5Tp07Gu+++a1tu0aKF8dZbb9mWJRlPPvmkbfn06dOGxWIx5s+fb/ddRdPCT58+3ZBk7N6923bM+++/b8TExNiWY2JijNdee822XFBQYDRv3twYPXq00zq7du1qPPvssw63layhqM6QkBBj1apVdvv+9a9/NW655Ra742bMmGHbfuLECSM0NNSYOXOm01oAVB59VADY9OnTx275zJkzeu655/TDDz8oJSVFBQUFOnv2bLktKt26dbN9DgsLU0REhNLS0pzuX6dOHbVp08a23KRJE9v+GRkZOnr0qPr162fb7u/vr969e8tqtTo954QJE3Tfffdp0aJFGjJkiK677jq7ukravn27cnJydPnll9utz8vLU8+ePe3WJSQk2D5HRUWpffv22rFjh9NzA6g8ggoAm7CwMLvlRx55RAsXLtTrr7+u+Ph4hYaG6vrrr1deXl6Z5wkMDLRbtlgsZYYKR/sbJV5ILPmGUMntJd1zzz0aOnSo5s6dq0WLFikxMVFvvPGG7r//fof7F9U3d+5cNW3a1G5bcHBwmd/lqD4A7kEfFaCWCQoKUmFhoUv7rly5UnfeeaeuueYade3aVY0bN9b+/furt8ASIiMjFRMTo3Xr1tnWFRYWatOmTeUeGxcXp7Fjx+rbb7/VQw89pI8//ljSud9B0XmKdOrUScHBwTp48KDi4+Ptfor345GkNWvW2D6np6crOTlZHTp0qNJ1AnCMFhWglmnZsqXWrl2r/fv3Kzw8XFFRUU73jY+P17fffqtRo0bJYrHoqaeeKrNlpLrcf//9SkxMVHx8vDp06KB3331X6enpZbZiPPDAA7ryyivVrl07paen66efflLHjh0lSS1atJDFYtEPP/yg4cOHKzQ0VBEREXr44Yf14IMPymq1auDAgcrMzNSqVasUHh6uMWPG2M79/PPPq0GDBoqJidE//vEPNWzYsMw3kABUHi0qQC3z8MMPy9/fX506dVKjRo3K7G/y1ltvqX79+rrwwgs1atQoDR06VL169arBas957LHHdMstt+gvf/mLEhISFB4erqFDhyokJMTpMYWFhRo3bpw6duyoYcOGqX379vrggw8kSU2bNtVzzz2nxx9/XDExMRo/frwk6YUXXtDTTz+txMREdezYUUOHDtX//vc/tWrVyu7cU6ZM0cSJE9W7d2+lpqZqzpw5tlYaAO7FyLQAvI7ValXHjh1144036oUXXqix72VEW6Dm8egHgMc7cOCAFi1apEsuuUS5ubl67733tG/fPt16661mlwagmvHoB4DH8/Pz02effaa+fftqwIAB2rp1q5YsWWLrcwLAd/HoBwAAeCxaVAAAgMciqAAAAI9FUAEAAB6LoAIAADwWQQUAAHgsggoAAPBYBBUAAOCxCCoAAMBj/X/I6iLGN6JolgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "\n",
    "# 训练序列到序列模型\n",
    "def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\\\n",
    "        learning_rate=1e-3):\n",
    "    # 准备模型和优化器\n",
    "    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "    encoder.zero_grad()\n",
    "    decoder.zero_grad()\n",
    "\n",
    "    step_losses = []\n",
    "    plot_losses = []\n",
    "    with trange(n_epochs, desc='epoch', ncols=60) as pbar:\n",
    "        for epoch in pbar:\n",
    "            np.random.shuffle(train_data)\n",
    "            for step, data in enumerate(train_data):\n",
    "                # 将源序列和目标序列转为 1 * seq_len 的tensor\n",
    "                # 这里为了简单实现，采用了批次大小为1，\n",
    "                # 当批次大小大于1时，编码器需要进行填充\n",
    "                # 并且返回最后一个非填充词的隐状态，\n",
    "                # 解码也需要进行相应的处理\n",
    "                input_ids, target_ids = data\n",
    "                input_tensor, target_tensor = \\\n",
    "                    torch.tensor(input_ids).unsqueeze(0),\\\n",
    "                    torch.tensor(target_ids).unsqueeze(0)\n",
    "\n",
    "                encoder_optimizer.zero_grad()\n",
    "                decoder_optimizer.zero_grad()\n",
    "\n",
    "                encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "                # 输入目标序列用于teacher forcing训练\n",
    "                decoder_outputs, _, _ = decoder(encoder_outputs,\\\n",
    "                    encoder_hidden, target_tensor)\n",
    "\n",
    "                loss = criterion(\n",
    "                    decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "                    target_tensor.view(-1)\n",
    "                )\n",
    "                pbar.set_description(f'epoch-{epoch}, '+\\\n",
    "                    f'loss={loss.item():.4f}')\n",
    "                step_losses.append(loss.item())\n",
    "                # 实际训练批次为1，训练损失波动过大\n",
    "                # 将多步损失求平均可以得到更平滑的训练曲线，便于观察\n",
    "                plot_losses.append(np.mean(step_losses[-32:]))\n",
    "                loss.backward()\n",
    "\n",
    "                encoder_optimizer.step()\n",
    "                decoder_optimizer.step()\n",
    "\n",
    "    plot_losses = np.array(plot_losses)\n",
    "    plt.plot(range(len(plot_losses)), plot_losses)\n",
    "    plt.xlabel('training step')\n",
    "    plt.ylabel('loss')\n",
    "    plt.show()\n",
    "\n",
    "    \n",
    "hidden_size = 128\n",
    "n_epochs = 20\n",
    "learning_rate = 1e-3\n",
    "\n",
    "encoder = RNNEncoder(input_lang.n_words, hidden_size)\n",
    "decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)\n",
    "\n",
    "train_seq2seq_mt(train_data, encoder, decoder, n_epochs, learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "dee25189",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 透 明 国 度 马 克 笔 动 漫 人 物 插 画 手 绘 技 法\n",
      "target： a transparent country , mark's manipulator's intuitive painting technique .\n",
      "pred： a transparent country , mark's manipulator's intuitive painting technique .\n",
      "\n",
      "input： 全 国 计 算 机 等 级 考 试 模 拟 考 场 二 级 c 语 言\n",
      "target： the national computer level examination simulator , level 2c .\n",
      "pred： the national computer level examination simulator , level 2c .\n",
      "\n",
      "input： 水 彩 色 铅 笔 之 旅\n",
      "target： water-coloured pencil trip .\n",
      "pred： water-coloured pencil trip , and it's beautiful .\n",
      "\n",
      "input： 大 学 生 职 业 规 划 与 就 业 指 导 教 程\n",
      "target： vocational planning and employment guidance for university students\n",
      "pred： vocational planning and employment guidance for university students\n",
      "\n",
      "input： 原 画 梦 3 0 天 学 会 日 系 插 画 第 2 版\n",
      "target： original dream , 30 days of academies day illustration , 2nd edition .\n",
      "pred： original dream , 30 days of academies day illustration , version 2 .\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "代码修改自GitHub项目pytorch/tutorials\n",
    "（Copyright (c) 2023, PyTorch, BSD-3-Clause License（见附录））\n",
    "\"\"\"\n",
    "def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    with torch.no_grad():\n",
    "        # 将源序列转为 1 * seq_length 的tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "        \n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        decoder_outputs, decoder_hidden, decoder_attn = \\\n",
    "            decoder(encoder_outputs, encoder_hidden)\n",
    "        \n",
    "        # 取出每一步预测概率最大的词\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        \n",
    "        decoded_ids = []\n",
    "        for idx in topi.squeeze():\n",
    "            if idx.item() == EOS_token:\n",
    "                break\n",
    "            decoded_ids.append(idx.item())\n",
    "    return output_lang.ids2sent(decoded_ids), decoder_attn\n",
    "            \n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence, _ = greedy_decode(encoder, decoder, pair[0],\n",
    "        input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "11053694",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 国 之 重 器 出 版 工 程 高 光 谱 卫 星 图 像 协 同 处 理 理 论 与 方 法\n",
      "target： theorems and methods of co-processing of hyperspectral satellite images .\n",
      "pred： theorems and methods of co-processing of hyperspectral satellite images .\n",
      "\n",
      "input： 名 画 的 诞 生 给 孩 子 的 艺 术 长 卷 梵 高 在 画 画\n",
      "target： the birth of the famous painting , the roll of art for the children , the van gogh is painting .\n",
      "pred： the birth of the famous painting , the roll of art for the children , the van gogh is painting .\n",
      "\n",
      "input： 乐 理 自 学 让 你 轻 松 掌 握 简 谱\n",
      "target： music is self-learning , so that you can easily grasp the brevity .\n",
      "pred： music is self-learning , so you can easily master the five-line spectra .\n",
      "\n",
      "input： v u e . j s 前 端 开 发 技 术\n",
      "target： vue .js front end development technology\n",
      "pred： vue .js front end development technology majority\n",
      "\n",
      "input： r p a 落 地 指 南\n",
      "target： rpa landing guide\n",
      "pred： rpa landing learning guide case guide\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 定义容器类用于管理所有的候选结果\n",
    "class BeamHypotheses:\n",
    "    def __init__(self, num_beams, max_length):\n",
    "        self.max_length = max_length\n",
    "        self.num_beams = num_beams\n",
    "        self.beams = []\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "    \n",
    "    # 添加一个候选结果，更新最差得分\n",
    "    def add(self, sum_logprobs, hyp, hidden):\n",
    "        score = sum_logprobs / max(len(hyp), 1)\n",
    "        if len(self) < self.num_beams or score > self.worst_score:\n",
    "            # 可更新的情况：数量未饱和或超过最差得分\n",
    "            self.beams.append((score, hyp, hidden))\n",
    "            if len(self) > self.num_beams:\n",
    "                # 数量饱和需要删掉一个最差的\n",
    "                sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                    (s, _, _) in enumerate(self.beams)])\n",
    "                del self.beams[sorted_scores[0][1]]\n",
    "                self.worst_score = sorted_scores[1][0]\n",
    "            else:\n",
    "                self.worst_score = min(score, self.worst_score)\n",
    "    \n",
    "    # 取出一个未停止的候选结果，第一个返回值表示是否成功取出，\n",
    "    # 如成功，则第二个值为目标候选结果\n",
    "    def pop(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        for i, (s, hyp, hid) in enumerate(self.beams):\n",
    "            # 未停止的候选结果需满足：长度小于最大解码长度；不以<eos>结束\n",
    "            if len(hyp) < self.max_length and (len(hyp) == 0\\\n",
    "                    or hyp[-1] != EOS_token):\n",
    "                del self.beams[i]\n",
    "                if len(self) > 0:\n",
    "                    sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                        (s, _, _) in enumerate(self.beams)])\n",
    "                    self.worst_score = sorted_scores[0][0]\n",
    "                else:\n",
    "                    self.worst_score = 1e9\n",
    "                return True, (s, hyp, hid)\n",
    "        return False, None\n",
    "    \n",
    "    # 取出分数最高的候选结果，第一个返回值表示是否成功取出，\n",
    "    # 如成功，则第二个值为目标候选结果\n",
    "    def pop_best(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        sorted_scores = sorted([(s, idx) for idx, (s, _, _)\\\n",
    "            in enumerate(self.beams)])\n",
    "        return True, self.beams[sorted_scores[-1][1]]\n",
    "\n",
    "\n",
    "def beam_search_decode(encoder, decoder, sentence, input_lang,\n",
    "        output_lang, num_beams=3):\n",
    "    with torch.no_grad():\n",
    "        # 将源序列转为 1 * seq_length 的tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        # 在容器中插入一个空的候选结果\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        init_hyp = []\n",
    "        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)\n",
    "        hypotheses.add(0, init_hyp, encoder_hidden)\n",
    "\n",
    "        while True:\n",
    "            # 每次取出一个未停止的候选结果\n",
    "            flag, item = hypotheses.pop()\n",
    "            if not flag:\n",
    "                break\n",
    "                \n",
    "            score, hyp, decoder_hidden = item\n",
    "            \n",
    "            # 当前解码器输入\n",
    "            if len(hyp) > 0:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(hyp[-1])\n",
    "            else:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(SOS_token)\n",
    "\n",
    "            # 解码一步\n",
    "            decoder_output, decoder_hidden, _ = decoder.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "\n",
    "            # 从输出分布中取出前k个结果\n",
    "            topk_values, topk_ids = decoder_output.topk(num_beams)\n",
    "            # 生成并添加新的候选结果到容器\n",
    "            for logp, token_id in zip(topk_values.squeeze(),\\\n",
    "                    topk_ids.squeeze()):\n",
    "                sum_logprobs = score * len(hyp) + logp.item()\n",
    "                new_hyp = hyp + [token_id.item()]\n",
    "                hypotheses.add(sum_logprobs, new_hyp, decoder_hidden)\n",
    "\n",
    "        flag, item = hypotheses.pop_best()\n",
    "        if flag:\n",
    "            hyp = item[1]\n",
    "            if hyp[-1] == EOS_token:\n",
    "                del hyp[-1]\n",
    "            return output_lang.ids2sent(hyp)\n",
    "        else:\n",
    "            return ''\n",
    "\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence = beam_search_decode(encoder, decoder,\\\n",
    "        pair[0], input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')"
   ]
  },
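  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0a1b2c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (not part of the original tutorial): compare greedy and\n",
    "# beam-search decoding on one random pair; num_beams=5 is an arbitrary choice.\n",
    "pair = random.choice(pairs)\n",
    "greedy_out, _ = greedy_decode(encoder, decoder, pair[0],\\\n",
    "    input_lang, output_lang)\n",
    "beam_out = beam_search_decode(encoder, decoder, pair[0],\\\n",
    "    input_lang, output_lang, num_beams=5)\n",
    "print('input:', pair[0])\n",
    "print('greedy:', greedy_out)\n",
    "print('beam:', beam_out)"
   ]
  },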
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da06c77f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
